#!/bin/sh
_=[[
# vim: filetype=lua ts=2 sts=2 sw=2 et ai
if command -v luajit >/dev/null 2>&1; then
  exec luajit "$0" "$@"
fi
if command -v lua >/dev/null 2>&1; then
  exec lua "$0" "$@"
fi
echo 'requires lua or luajit in $PATH'
exit 1
]]

-- Basename of the invoked script (everything after the last '/').
local app = (arg[0]):match(".-([^%/]+)$")
-- Directory part of the invoked script, including the trailing '/'.
local script_dir = (arg[0]):match("(.-)[^%/]+$")
-- When the script is started without any directory component (e.g.
-- `lua mekongrw` from inside its own directory) the match above yields "".
-- Without this fallback the search path would become the absolute path
-- "/../lib/lua/?.lua" instead of one relative to the script.
if script_dir == "" then
  script_dir = "./"
end
-- Modules are looked up only in ../lib/lua relative to the script; the default
-- package.path is replaced, not extended.
package.path = script_dir .. "/../lib/lua/?.lua"
local cargs = require('cargs')
local yaml = require('yaml')
-- Placeholder; presumably substituted by the build system -- TODO confirm.
local VERSION = '@MEKONG_VERSION@'

-- Formats `msg` with the remaining arguments (string.format semantics) and
-- prints the resulting line to stdout.
local function printf(msg, ...)
  local line = string.format(msg, ...)
  print(line)
end

-- Formats `msg` with the remaining arguments (string.format semantics) and
-- writes the result, terminated by a newline, to stderr.
local function logf(msg, ...)
  local line = string.format(msg, ...)
  io.stderr:write(line .. "\n")
end

-- Runtime type check for the parameters of the *calling* function.
--
-- Usage: call `check_types('string', 'number', ...)` as the first statement
-- of a function; the n-th type name is compared against `type()` of the
-- caller's n-th local variable (parameters are a function's first locals).
--
-- Only as many locals as type names were given are inspected, so locals the
-- caller declares beyond its checked parameters are never mistaken for
-- arguments. If the caller has fewer locals than type names, checking stops
-- at the last local.
--
-- Raises an error naming the offending parameter on the first mismatch.
local function check_types(...)
  local level = 2 -- stack level of the function that called check_types
  local type_spec = {...}
  local info = debug.getinfo(level, "nSl")
  for i = 1, #type_spec do
    local name, value = debug.getlocal(level, i)
    if not name then break end

    local actual = type(value)
    if actual ~= type_spec[i] then
      -- info.name can be nil (e.g. when the caller was invoked from C code),
      -- so it is passed through tostring to keep the format call itself valid.
      error(string.format("%s (%s:%d) expected %s for arg '%s' (%d), got %s",
        tostring(info.name), info.short_src, info.linedefined, type_spec[i], name, i, actual))
    end
  end
end

-------------------------------------------------------------------------------
-- pattern matchers / substitutors
-------------------------------------------------------------------------------

-- Replaces every CUDA kernel launch statement in `str` with the result of `fn`.
--
-- Recognised syntax: $IDENTIFIER \s* <<< $CONFIG >>> \s* ( $PARAMETERS ) \s* ;
-- The trailing ';' is part of the replaced text.
--
-- fn(whole, name, grid, c_args) is called per launch with
--   whole:  the launch text without the trailing ';'
--   name:   kernel identifier
--   grid:   list of launch configuration expressions (split at ',')
--   c_args: list of kernel argument expressions (split at ',')
-- and its return value is used as gsub replacement (a nil/false result keeps
-- the original text, as usual for string.gsub).
--
-- Returns the rewritten string.
local function substitute_kernel_launch(str, fn)
  local PATTERN = "(".."([_a-zA-Z][_a-zA-Z0-9]*)" .. -- start + identifier
                  "%s*<<<%s*(.-)%s*>>>%s*" ..        -- kernel grid configuration
                  "(%b())"..")%s*;"                   -- balanced parentheses (contains kernel arguments) + stop
  local CHECK = "[^()%[%]{}]"
  -- Only the first gsub result (the string) is kept; the substitution count is
  -- deliberately dropped instead of being stored in a global.
  local rewritten = string.gsub(str, PATTERN, function(whole, name, grid, c_args)
    local arg_text = string.sub(c_args, 2, -2) -- remove parentheses around kernel arguments
    -- NOTE(review): these assertions succeed as soon as the text contains at
    -- least one non-bracket character, i.e. they reject an empty grid
    -- configuration but do NOT reject nested brackets. Expressions with commas
    -- inside brackets (e.g. dim3(a, b)) are therefore still split at every comma.
    assert(string.match(grid, CHECK), "unsupported kernel grid configuration")
    assert(arg_text == "" or string.match(arg_text, CHECK), "unsupported kernel arguments")

    local grid_list = string.split(grid, ",")    -- launch configuration as list
    local arg_list = string.split(arg_text, ",") -- kernel arguments as list

    return fn(whole, name, grid_list, arg_list)
  end)
  return rewritten
end

-- Replaces every kernel declaration header in `str` with the result of `fn`.
--
-- Recognised syntax: __global__ %s+ $TYPE %s+ $NAME %s* ( $PARAMETERS )
-- where $TYPE and $NAME are single identifiers.
--
-- fn(whole, name, c_args) is called per declaration with the matched text,
-- the kernel name and the parameter declarations split at ','.
--
-- Returns the rewritten string.
local function substitute_kernel_decl(str, fn)
  local ID = "[_a-zA-Z][_a-zA-Z0-9]*"
  local PATTERN = "(".."__global__%s+"..ID.."%s+("..ID..")%s*(%b())"..")"
  -- Only the first gsub result (the string) is kept; the substitution count is
  -- deliberately dropped instead of being stored in a global.
  local rewritten = string.gsub(str, PATTERN, function(whole, name, c_args)
    local param_text = string.sub(c_args, 2, -2) -- remove parentheses around kernel parameters
    -- parameters are split at every comma; no bracket check is performed here
    return fn(whole, name, string.split(param_text, ","))
  end)
  return rewritten
end

-- Prepends generated text to `str`.
-- `fn` is called once with two empty strings; whatever it returns is placed
-- in front of the unchanged input.
local function substitute_include(str, fn)
  local header = fn("", "")
  return header .. str
end

-------------------------------------------------------------------------------
-- name lookup / translation / queries
-------------------------------------------------------------------------------

-- Translates a type name from LLVM IR to the corresponding C type name.
--
-- Supported inputs:
--   * the scalar types listed in `simple` below (e.g. "i32" -> "int32_t")
--   * a single-level pointer to one of those (e.g. "float*" -> "float*")
--   * a single-level pointer to a named struct: "%struct.NAME*" -> "NAME*"
-- Raises an error for every other type name.
local function irttoc(llvmname)
  check_types('string')
  -- basic types
  local simple = {
    ["i8"] = "char",
    ["i16"] = "int16_t",
    ["i32"] = "int32_t",
    ["i64"] = "int64_t",
    ["half"] = "half",
    ["float"] = "float",
    ["double"] = "double",
    ["fp128"] = "__float128",
  }
  if simple[llvmname] then
    return simple[llvmname]
  end
  -- simple pointer types
  local ptrt = string.match(llvmname, "^([_a-zA-Z0-9]+)[*]$")
  if ptrt and simple[ptrt] then
    return simple[ptrt] .. "*"
  end
  -- struct pointer types; the '.' after "struct" is escaped so that only a
  -- literal dot is accepted ('.' alone matches any character in a Lua pattern)
  local strptrt = string.match(llvmname, "^%%struct%.([_a-zA-Z0-9]+[*])$")
  if strptrt then
    return strptrt
  end
  error("unsupported llvm type: " .. llvmname)
end

-- Returns the name of the partitioned variant of `kernel`:
-- "__<kernel>_subgrid".
local function partitioned_kernel_name(kernel)
  check_types('string')
  local template = "__%s_subgrid"
  return template:format(kernel)
end

-- Returns the name of the access iterator for argument number `array` of
-- `kernel`: "__<kernel>_<array>_<mode>".
-- `mode` must be "read" or "write"; anything else fails an assertion.
local function partitioned_kernel_iterator(kernel, array, mode)
  check_types('string', 'number', 'string')
  local supported = (mode == "read") or (mode == "write")
  assert(supported, "only read and write iterators supported")
  local template = "__%s_%d_%s"
  return template:format(kernel, array, mode)
end

-- Maps a partitioning scheme name ("linear:x", "linear:y", "linear:z") to the
-- name of the partitioning function emitted into the generated code.
-- Raises an error for any other scheme.
local function partition_fn_for(scheme)
  check_types('string')
  local by_scheme = {
    ["linear:x"] = "__me_partition_linear_x",
    ["linear:y"] = "__me_partition_linear_y",
    ["linear:z"] = "__me_partition_linear_z",
  }
  local fn_name = by_scheme[scheme]
  if fn_name == nil then
    error("invalid partitioning scheme: " .. scheme)
  end
  return fn_name
end

-- Returns a new list with the elements of the sequence `values` for which
-- `test(element)` is truthy, in their original order.
local function filter(values, test)
  check_types('table', 'function')
  local kept, count = {}, 0
  for _, item in ipairs(values) do
    if test(item) then
      count = count + 1
      kept[count] = item
    end
  end
  return kept
end

-- Translates an expression passed to a kernel launch.
-- Arguments whose analysis entry has a truthy 'is-pointer' field are wrapped
-- in a cast plus a call to __me_nth_array; all others are returned unchanged.
-- c_expr:     c expression passed to kernel launch
-- argument:   argument info from kernel analysis
-- c_instance: c expression with instance number
local function translate_argument(c_expr, argument, c_instance)
  check_types('string', 'table', 'string')
  if argument['is-pointer'] then
    local c_type = irttoc(argument['type-name'])
    return string.format("(%s)__me_nth_array(%s, %s)", c_type, c_expr, c_instance)
  end
  return c_expr
end

-------------------------------------------------------------------------------
-- templating engine
-------------------------------------------------------------------------------

-- Splits a string into a list of fields.
--
-- sep: set of separator characters (default ','). Every character of `sep`
--      acts as a separator on its own; characters that are magic in Lua
--      patterns ('%', ']', '^', '-', ...) are escaped and treated literally.
--
-- Leading and trailing whitespace is removed from every field. Fields that
-- are empty or consist only of whitespace are skipped.
--
-- Returns the list of fields (possibly empty).
string.split = function(self, sep)
  -- '%' followed by a non-alphanumeric character stands for that literal
  -- character, also inside a character class.
  local class = string.gsub(sep or ',', "%W", "%%%0")
  local fields = {}
  for chunk in string.gmatch(self, "[^" .. class .. "]+") do
    local field = string.match(chunk, "^%s*(.-)%s*$")
    if field ~= "" then
      fields[#fields+1] = field
    end
  end
  return fields
end

-- Expands template expressions in a string.
--
-- Two forms are processed, in this order:
--   {{fn ... }} -> the text after 'fn' is compiled as a function body and run.
--                  Inside it, write(...) appends its arguments and print(...)
--                  appends its arguments plus a newline; the collected output
--                  replaces the expression (together with any spaces directly
--                  in front of it). The body sees write, print, pairs, ipairs
--                  and, as fallback, every field of `env`.
--   {{ expr }}  -> replaced by tostring() of the value `expr` evaluates to,
--                  with `env` as the only visible environment.
--
-- Works with Lua 5.1 / LuaJIT (loadstring + setfenv) as well as Lua 5.2+
-- (load with an explicit environment). Compile errors in an embedded chunk
-- are raised via assert.
--
-- Returns the expanded string.
string.template = function(self, env)
  -- Compiles `source` so that it runs with `fenv` as its global environment.
  local function compile(source, fenv)
    if setfenv then
      local fn = assert(loadstring(source))
      setfenv(fn, fenv)
      return fn
    end
    -- Lua 5.2+: setfenv/loadstring are gone; load takes the environment directly
    return assert(load(source, nil, "t", fenv))
  end

  local expanded = string.gsub(self, " *{{fn%s+(.-)}}", function(fnstr)
    local parts = {} -- collected output, joined once at the end
    local function fakeWrite(...)
      for _, val in ipairs({...}) do
        parts[#parts+1] = tostring(val)
      end
    end
    local fnenv = setmetatable({
      write = fakeWrite,
      print = function(...) fakeWrite(...) fakeWrite("\n") end,
      pairs = pairs,
      ipairs = ipairs,
    }, { __index = env })
    compile(fnstr, fnenv)()
    return table.concat(parts)
  end)

  -- Second pass: this also sees (and expands) any "{{...}}" produced by the
  -- first pass.
  expanded = string.gsub(expanded, "{{(.-)}}", function(m)
    return tostring(compile("return " .. m, env)())
  end)
  return expanded
end

-------------------------------------------------------------------------------
-- CLI
-------------------------------------------------------------------------------

-- Command line definition: one or more positional input files, an output
-- file (-o), any number of YAML analysis files (-info= / --info=) and a
-- version flag.
local app = cargs.App:new('mekongrw', 'The mekong rewriter.', '<input file>')
app:add_option('joinsep', 'once', {'-'}, 'o', nil, 'The output file.', '<file>')
app:add_option('joined', 'list', {'-', '--'}, 'info=', 'info', 'YAML analysis for the input file.', '<file>')
app:add_option('flag', 'last', {'-', '--'}, 'version', nil, 'Show version and exit.')
-- kept local so that parsing does not create the globals `opts` and `args`
local opts, args = app:parse()

if opts['version'] then
  print('version: ' .. VERSION)
  os.exit(0)
end

-- positional arguments: source files to rewrite
local inputfiles = args or {}
-- YAML kernel analysis files
local kernelinfos = opts['info'] or {}
-- output file name; '-' selects stdout
local output = opts['o'] or '-'

if #kernelinfos < 1 then
  logf("warning: no kernel info files")
end

-------------------------------------------------------------------------------
-- Program logic
-------------------------------------------------------------------------------


-- Load input files
--
-- All input files are read completely and concatenated in command line order.
-- Any open or read failure is reported on stderr and terminates the program
-- with exit status 1.

local chunks = {}

for _, filename in ipairs(inputfiles) do
  local f, err = io.open(filename, "r")
  if f == nil then
    logf("error opening file '%s': %s", filename, err)
    os.exit(1)
  end
  local data, read_err = f:read("*a")
  f:close() -- release the handle right away instead of waiting for the GC
  if data == nil then
    logf("error reading file '%s': %s", filename, tostring(read_err))
    os.exit(1)
  end
  chunks[#chunks+1] = data
end

local contents = table.concat(chunks)

-- Load kernel infos
--
-- Every analysis file is parsed as YAML; the entries of its `kernels` list
-- are stored in `kernels`, indexed by kernel name. A kernel that appears in
-- several files is taken from the file listed last.

local kernels = {}
for _, filename in ipairs(kernelinfos) do
  local f, err = io.open(filename, "r")
  if f == nil then
    -- same reporting as for unreadable input files
    logf("error opening file '%s': %s", filename, err)
    os.exit(1)
  end
  local fcontents = f:read('*a')
  f:close()
  local info = yaml.eval(fcontents)
  -- NOTE(review): "..." is presumably what the YAML parser yields when the
  -- kernel list is empty -- confirm against the yaml module.
  if info.kernels ~= "..." then
    for _, k in ipairs(info.kernels) do
      kernels[k['name']] = k
    end
  end
end

-- set default values for kernels.
-- required because lua can not automatically infer that missing entries in
-- yaml might be arrays/maps
for _, kernel in pairs(kernels) do
  kernel['arguments'] = kernel['arguments'] or {}
end

-------------------------------------------------------------------------------
-- Substitutions
-------------------------------------------------------------------------------

-- Rewrite every kernel launch in the input.
--
-- For each launch the analysis entry of the kernel is looked up by name; a
-- launch of a kernel without analysis aborts the program with exit status 1.
-- Depending on the kernel's 'partitioning' field the launch is replaced by
-- one of two generated code blocks:
--   * partitioning ~= 'none': per-GPU read synchronisation, launch of the
--     partitioned kernel on every GPU, per-GPU write update
--   * partitioning == 'none': synchronise all read arrays, launch the
--     original kernel on device 0, update all written arrays
contents = substitute_kernel_launch(contents, function(orig, name, grid, c_args)
  local kernel = kernels[name]

  if kernel == nil then
    logf("Kernel analysis not available for kernel %s", name)
    os.exit(1)
  end

  -- prepare template environment: everything the templates below may reference
  local env = {
    name = name,
    grid = grid,
    c_args = c_args,
    kernel = kernel,
    fmt = string.format,
    join = table.concat,

    filter = filter,
    _kernel_name = partitioned_kernel_name,
    _kernel_iterator = partitioned_kernel_iterator,
    partition_fn_for = partition_fn_for,
    translate_argument = translate_argument,

    -- predicates on argument entries of the kernel analysis
    is_parameter = function(argument) return argument['is-parameter'] end,
    is_read = function(argument) return argument['read-map'] ~= '' end,
    is_write = function(argument) return argument['write-map'] ~= '' end,

  }

  -- splittable
  if kernel['partitioning'] ~= 'none' then

    return string.template([[

  // ++++++++ PARTITIONED KERNEL {{name}} <<< {{join(grid, ", ")}} >>> ({{join(c_args, ", ")}})
    {
      dim3 _me_grid = {{ grid[1] }};
      dim3 _me_block = {{ grid[2] }};
      size_t _me_shared = {{ grid[3] or 0 }};
      cudaStream_t _me_stream = {{ grid[4] or 0 }};

      __subgrid_t _me_subgrid;
      int64_t _me_params[{{ #filter(kernel['arguments'], is_parameter) }}];
      _me_subgrid.full.zdim = _me_block.z;
      _me_subgrid.full.ydim = _me_block.y;
      _me_subgrid.full.xdim = _me_block.x;
      int _me_ngpus = __me_num_gpus();

      // populate params for polyhedral model
      {{fn
        local currIdx = 0
        for idx, argument in ipairs(kernel['arguments']) do
          if is_parameter(argument) then
            print(fmt("      _me_params[%d] = (int64_t)%s;", currIdx, c_args[idx]))
            currIdx = currIdx + 1
          end
      end}}
      // synchronize arrays read arrays
      for (int _me_i = 0; _me_i < _me_ngpus; ++_me_i) {
        {{ partition_fn_for(kernel['partitioning']) }}(&_me_subgrid, _me_i, _me_ngpus, _me_grid);
        __me_itfn_t _me_iter = NULL;
        int _me_elsize = 0;
        {{fn for idx, argument in ipairs(kernel['arguments']) do
          if is_read(argument) then 
            print(fmt("      // synchronize reads for array %s", c_args[idx]))
            print(fmt("      _me_iter = %s;", _kernel_iterator(name, idx-1, 'read')))
            print(fmt("      _me_elsize = %s / 8;", argument['element-bitsize']))
            print(fmt("      __me_buffer_sync(%s, _me_i, _me_iter, _me_elsize, &_me_subgrid, _me_params);", c_args[idx]))
          end
        end}}
      }
      // sync so buffers are up to date
      __me_sync();

      // launch kernels
      for (int _me_i = 0; _me_i < _me_ngpus; ++_me_i) {
        {{ partition_fn_for(kernel['partitioning']) }}(&_me_subgrid, _me_i, _me_ngpus, _me_grid);
        unsigned int z = _me_subgrid.kernel.zmax - _me_subgrid.kernel.zmin;
        unsigned int y = _me_subgrid.kernel.ymax - _me_subgrid.kernel.ymin;
        unsigned int x = _me_subgrid.kernel.xmax - _me_subgrid.kernel.xmin;
        cudaSetDevice(_me_i);
        {{ _kernel_name(name) }}<<< dim3(x, y, z), _me_block, _me_shared, _me_stream >>>
          ({{fn for idx, arg in ipairs(c_args) do
             write(translate_argument(arg, kernel['arguments'][idx], "_me_i"))
             write(", ")
           end}}_me_subgrid.kernel);
      }

      // update writes (parallel to kernel execution)
      for (int _me_i = 0; _me_i < _me_ngpus; ++_me_i) {
        {{ partition_fn_for(kernel['partitioning']) }}(&_me_subgrid, _me_i, _me_ngpus, _me_grid);
        __me_itfn_t _me_iter = NULL;
        int _me_elsize = 0;
        {{fn for idx, argument in ipairs(kernel['arguments']) do
          if is_write(argument) then
            print(fmt("      // update writes for array %s", c_args[idx]))
            print(fmt("      _me_iter = %s;", _kernel_iterator(name, idx-1, 'write')))
            print(fmt("      _me_elsize = %s / 8;", argument['element-bitsize']))
            print(fmt("      __me_buffer_update(%s, _me_i, _me_iter, _me_elsize, &_me_subgrid, _me_params);", c_args[idx]))
          end
        end}}
      }
      cudaSetDevice(0);
      // wait for kernels to finish (necessary?)
      //__me_sync();
    }
  // -------- {{name}} <<< {{join(grid, ", ")}} >>> ({{join(c_args, ", ")}})
    ]], env)

  -- unsplittable
  else

    logf("unsplittable kernel '%s'", name)

    -- The comment emitted for written arrays reads "update writes", matching
    -- the partitioned branch above (it used to be labelled as a read sync).
    return string.template([[
  // ++++++++ UNPARTITIONED KERNEL {{name}} <<< {{join(grid, ", ")}} >>> ({{join(c_args, ", ")}})
    {
      {{fn for idx, argument in ipairs(kernel['arguments']) do
        if is_read(argument) then 
          print(fmt("      // synchronize reads for array %s", c_args[idx]))
          print(fmt("      __me_buffer_sync_all(%s, 0);", c_args[idx]));
        end
      end}}
      __me_sync();

      cudaSetDevice(0);
      {{ name }}<<< {{grid[1]}}, {{grid[2]}}, {{grid[3] or "0"}}, {{grid[4] or "0"}} >>>
        ({{fn for idx, arg in ipairs(c_args) do
          if idx > 1 then write(", ") end
          write(translate_argument(arg, kernel['arguments'][idx], "0"))
        end}});

      {{fn for idx, argument in ipairs(kernel['arguments']) do
        if is_write(argument) then 
          print(fmt("      // update writes for array %s", c_args[idx]))
          print(fmt("      __me_buffer_update_all(%s, 0);", c_args[idx]));
        end
      end}}
    }
  // -------- UNPARTITIONED KERNEL {{name}} <<< {{join(grid, ", ")}} >>> ({{join(c_args, ", ")}})
    ]], env)
  end
end)

-------------------------------------------------------------------------------
-- Prepend the mekong header to the rewritten source.
--
-- For every kernel with partitioning ~= 'none' the header contains
--   * an empty definition of the partitioned kernel (argument types translated
--     from the analysis plus a trailing __subgrid_kernel_t parameter), and
--   * declarations of the read and write iterator for every pointer argument
--     (iterator numbering is zero based).
-- Kernels are emitted sorted by name so that the generated header does not
-- depend on the unspecified iteration order of `pairs`.
contents = substitute_include(contents, function(_, _)
  local names = {}
  for name in pairs(kernels) do
    names[#names+1] = name
  end
  table.sort(names)

  local ITERATOR_DECL = "extern \"C\" void %s(int64_t grid[], int64_t param[], __me_cbfn_t, void*)"
  local decls = {}
  for _, name in ipairs(names) do
    local kernel = kernels[name]
    if kernel['partitioning'] ~= 'none' then
      -- kernel stub: every argument type is followed by ", "
      local parts = { string.format("extern \"C\" __global__ void %s(", partitioned_kernel_name(name)) }
      for _, arg in ipairs(kernel['arguments']) do
        parts[#parts+1] = irttoc(arg['type-name']) .. ", "
      end
      parts[#parts+1] = " __subgrid_kernel_t) {}"
      decls[#decls+1] = table.concat(parts)

      for idx, arg in ipairs(kernel['arguments']) do
        if arg['is-pointer'] then
          decls[#decls+1] = string.format(ITERATOR_DECL, partitioned_kernel_iterator(name, idx-1, "read"))
          decls[#decls+1] = string.format(ITERATOR_DECL, partitioned_kernel_iterator(name, idx-1, "write"))
        end
      end
    end
  end

  -- the template only references `decls`
  local env = {
    decls = decls,
  }
  return string.template([[
///// start mekong header
#include <stdint.h>
#include "me-runtime.h"

// auto generated iterators etc

{{fn for _, decl in ipairs(decls) do
  print(decl, ";")
end}}

///// end mekong header

]], env)
end)

-------------------------------------------------------------------------------
-- Emit the rewritten source: to stdout when the output name is '-', otherwise
-- into the named file (followed by a newline). A file that cannot be opened
-- for writing is reported on stderr and ends the program with exit status 1.
if output ~= '-' then
  local f, err = io.open(output, "w")
  if f == nil then
    logf("unable to open output file: %s", err)
    os.exit(1)
  end
  f:write(contents)
  f:write("\n")
  f:close()
else
  print(contents)
end
